Computing lab: Non-parametric modelling with Gaussian Processes and approximations
Hello all!
The objective of this lab is to gain experience and familiarity in implementing computationally scalable non-parametric Bayesian normal linear regression models with Stan.
Let us revisit one more time the wildfire data from our previous computing lab, in which we calculated the point prevalence of US land on fire per calendar month:
Previously, we noted seasonal fluctuations in wildfire prevalence and modelled these with a linear year effect plus independent effects for each month. In this lab, we will model the monthly effects with random functions, giving these random functions either a Gaussian process (GP) prior with squared exponential kernel or, alternatively, a computationally more scalable Hilbert-space Gaussian process (HSGP) prior with squared exponential kernel.
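Both priors use the squared exponential kernel throughout. As a refresher, here is a minimal language-agnostic sketch (in Python for illustration only; the lab itself uses R and Stan) of the kernel with scale \(\alpha\) and lengthscale \(\rho\), assuming the common parameterisation \(k(x, x') = \alpha^2 \exp(-(x - x')^2 / (2\rho^2))\):

```python
import math

def sq_exp_kernel(x1, x2, alpha=1.0, rho=0.5):
    """Squared exponential kernel: alpha^2 * exp(-(x1 - x2)^2 / (2 * rho^2))."""
    return alpha**2 * math.exp(-((x1 - x2) ** 2) / (2.0 * rho**2))

# covariance between two monthly inputs on a standardised [0, 1] scale:
# nearby months are strongly correlated, distant months much less so
print(sq_exp_kernel(1/12, 2/12))
print(sq_exp_kernel(1/12, 7/12))
```

The lengthscale \(\rho\) controls how quickly the correlation between months decays; the example values of `alpha` and `rho` above are arbitrary.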
Let us load the data and consider log wildfire prevalence as before:
# load wildfire prevalence from lab 1
file <- file.path(data.dir,'US_wild_fires_prevalence.rds')
wfp <- readRDS(file)
land_area <- wfp[1, LANDAREA]
# select variables needed to answer study question
wfp <- subset(wfp, select = c(FIRE_YEAR, MONTH, PREVALENCE))
str(wfp)
#> Classes 'data.table' and 'data.frame': 289 obs. of 3 variables:
#> $ FIRE_YEAR : int 2005 2005 2005 2005 2005 2006 2006 2006 2007 2007 ...
#> $ MONTH : chr "05" "06" "07" "03" ...
#> $ PREVALENCE: num 2.09e-04 1.83e-03 1.48e-03 7.32e-05 9.06e-04 ...
#> - attr(*, ".internal.selfref")=<externalptr>
# set up log prevalence as response variable for a normal regression model
wfp[, LOG_PREV := log(PREVALENCE)]
# set up year covariate
wfp[, FIRE_YEAR_2 := log(FIRE_YEAR)]
wfp[, FIRE_YEAR_2 := (FIRE_YEAR_2 - mean(FIRE_YEAR_2))/sd(FIRE_YEAR_2) ]

To model month effects non-parametrically, we need to set up an index that associates the value of our random function \(f\) for a particular month with each observation. We will also need to define standardised inputs for the random function that captures month effects:
# index that associates the value of our random function to each observation
set(wfp, NULL, 'MONTH', wfp[, as.integer(MONTH)])
# define standardised inputs, so off-the-shelf GP priors can be used
wfp[, INPUT := MONTH/12]
wfp[, INPUT_2 := (INPUT - mean(INPUT))/sd(INPUT)]
# add observation index for easy post-processing
setkey(wfp, FIRE_YEAR, MONTH)
wfp[, IDX := seq_len(nrow(wfp))]

Let us denote by \(y_i\) the log wildfire prevalence in the \(i\)th observation.
Previously, we modeled \(y_i\) with \[\begin{align*} & y_i \sim \text{Normal}(\mu_i, \sigma^2) \\ & \mu_i = \beta_0 + \beta_1 X_{i1} + \cdots + \beta_{13} X_{i13} \\ & \beta_0 \sim \text{Normal}(0, 100) \\ & \beta_j \sim \text{Normal}(0, 1) \\ & \sigma \sim \text{Half-Cauchy}(0,1) \end{align*}\] where \(X_{ij}\), \(j=1,\dotsc,12\) are binary indicators that evaluate to \(1\) if the \(i\)th observation corresponds to the \(j\)th month in a year and \(X_{i13}\) is the standardized log year associated with the \(i\)th observation.
Now, we model \(y_i\) with \[\begin{align*} & y_i \sim \text{Normal}(\mu_i, \sigma^2) \\ & \mu_i = \beta_0 + \beta_1 X_{i1} + f(\text{month}_i) \\ & \beta_0 \sim \text{Normal}(0, 100) \\ & \beta_1 \sim \text{Normal}(0, 1) \\ & f \sim \text{GP}(\alpha, \rho) \\ & \alpha \sim \text{Half-Cauchy}(0, 1) \\ & \rho \sim \text{Inv-Gamma}(5, 1) \\ & \sigma \sim \text{Half-Cauchy}(0,1) \end{align*}\] where \(X_{i1}\) is the standardized log year associated with the \(i\)th observation and so \(\beta_1\) models a linear annual effect, and \(f\) is a random function that is evaluated at monthly inputs and so captures month effects non-parametrically. The random function is given a zero-mean GP prior with squared exponential kernel with GP variance \(\alpha\) and lengthscale \(\rho\). The hyperparameters \(\alpha\), \(\rho\) are given default priors that are suitable for a standardised input domain \([0,1]\).
Here is the Stan model file.
Note that the joint distribution of \(f\) evaluated at a finite set of inputs is just a multivariate normal, and so we can straightforwardly generate samples from \(f\) through linear transformation of iid standard normal random variables (through the line \(f = L_f * z\)).
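To see why the linear transformation works: if \(z\) is iid standard normal and \(L_f\) is the Cholesky factor of the kernel matrix \(K\), then \(L_f z\) has covariance \(L_f L_f^\top = K\). Here is a toy sketch with a hand-rolled \(2 \times 2\) Cholesky factorisation (in Python for illustration; the inputs `xs` and the fixed `z` values are made up for the example):

```python
import math

def chol2(K):
    """Cholesky factor of a 2x2 positive-definite matrix, so that K = L L^T."""
    l11 = math.sqrt(K[0][0])
    l21 = K[1][0] / l11
    l22 = math.sqrt(K[1][1] - l21**2)
    return [[l11, 0.0], [l21, l22]]

def k(x1, x2, alpha=1.0, rho=0.5):
    """Squared exponential kernel."""
    return alpha**2 * math.exp(-((x1 - x2) ** 2) / (2 * rho**2))

xs = [0.1, 0.4]
K = [[k(a, b) for b in xs] for a in xs]
L = chol2(K)
# sanity check: L L^T reproduces K
recon = [[sum(L[i][m] * L[j][m] for m in range(2)) for j in range(2)] for i in range(2)]
# a draw from GP(0, K): transform pretend iid N(0, 1) draws z via f = L z
z = [0.3, -1.2]
f = [L[0][0] * z[0], L[1][0] * z[0] + L[1][1] * z[1]]
```

In the Stan model this is exactly `f = L_f * z`, just with a full \(12 \times 12\) kernel matrix and `cholesky_decompose`.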
Note also that the variance-covariance matrix must not contain zeros, and so \(f\) can only be evaluated at unique inputs. I chose to model month effects through random functions to make this point clear. Be sure you understand how the values of \(f\) at the unique, standardised inputs that represent months are mapped back to each observation in the line mu = beta0 + X * beta + f[map_unique_inputs_to_obs];.
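The mapping itself is plain array indexing: \(f\) holds one value per unique month, and `map_unique_inputs_to_obs` picks the right entry for each observation. A toy sketch (in Python for illustration, so 0-based, whereas Stan and R index from 1; the values of `f` are made up):

```python
# f evaluated at the 12 unique monthly inputs (toy values; f[0] is January)
f = [round(0.1 * m, 1) for m in range(1, 13)]
# month of each observation (1-based, as in the MONTH column)
month_of_obs = [5, 6, 7, 3, 5]
# month-effect contribution to mu for each observation:
# the same f value is reused whenever two observations share a month
month_effect = [f[m - 1] for m in month_of_obs]
```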
# compile Stan model
wfp_gp_model <- rstan::stan_model("stan_models/wfp_gp.stan", model_name = "wfp_gp")

Let us fit the model to our wildfire data. The map from \(f\) evaluated at monthly inputs to each observation is just the month \(1, 2, \dotsc\) associated with that observation:
# define data in format needed for model specification
stan_data <- list()
stan_data$N <- nrow(wfp)
stan_data$X <- unname(as.matrix(subset(wfp, select = FIRE_YEAR_2)))
stan_data$K <- ncol(stan_data$X)
stan_data$y <- wfp$LOG_PREV
stan_data$NI <- 12
stan_data$inputs_standardised <- unique(sort(wfp$INPUT_2))
stan_data$map_unique_inputs_to_obs <- wfp$MONTH
# sample from joint posterior of the GP model with rstan
# I initialized the sampler at values that help avoid -Inf likelihood evaluations
wfp_gp_fit <- rstan::sampling(wfp_gp_model,
data = stan_data,
seed = 123,
chains = 2,
cores = 2,
warmup = 500,
iter = 2000,
init = list(
list(beta0 = -9, gp_sigma = 1, gp_lengthscale = .5, sigma = 1),
list(beta0 = -9, gp_sigma = 1, gp_lengthscale = .5, sigma = 1)
))
#> Warning: There were 4 divergent transitions after warmup. See
#> https://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
#> to find out why this is a problem and how to eliminate them.
#> Warning: Examine the pairs() plot to diagnose sampling problems
# save output to RDS
saveRDS(wfp_gp_fit, file = file.path(out.dir, "wfp_gp_fit.rds"))

GP lengthscales that are very close to zero, or very large relative to the input domain, are often associated with divergent transitions in the HMC sampler. With the above prior choice on \(\rho\), you should find few if any divergent transitions, and can safely proceed if so.
As always we inspect convergence and mixing:
# load output from RDS
wfp_gp_fit <- readRDS(file.path(out.dir, "wfp_gp_fit.rds"))
# ====== Define a helper function to perform diagnostics =====
make_diagnostic_summary <- function(stan_fit, pars) {
# Extract the samples of the parameters
po_draws <- rstan::extract(stan_fit, pars = pars)
# Coerce them into a format that is easier to work with
po_draws <- posterior::as_draws(po_draws)
# Make a summary table
sum_tbl <- summary(
po_draws,
posterior::default_summary_measures(),
extra_quantiles = ~posterior::quantile2(., probs = c(.025, .975)),
posterior::default_convergence_measures()
)
return(sum_tbl)
}
# =============================================================
# A vector containing the names of parameters we want to extract
model_pars <- c("beta0", "beta", "sigma", "gp_lengthscale", "gp_sigma")
sum_tbl <- make_diagnostic_summary(wfp_gp_fit, model_pars)
# print summaries
kableExtra::kbl(sum_tbl, digits = 3) %>%
  kableExtra::kable_styling(bootstrap_options = c("striped", "hover", "condensed"), font_size = 12)

| variable | mean | median | sd | mad | q5 | q95 | q2.5 | q97.5 | rhat | ess_bulk | ess_tail |
|---|---|---|---|---|---|---|---|---|---|---|---|
| beta0 | -9.628 | -9.592 | 0.792 | 0.670 | -10.980 | -8.420 | -11.352 | -8.126 | 1.000 | 3007.884 | 2825.849 |
| beta | 0.161 | 0.162 | 0.049 | 0.050 | 0.079 | 0.240 | 0.063 | 0.256 | 1.000 | 2984.128 | 2844.480 |
| sigma | 0.839 | 0.837 | 0.036 | 0.035 | 0.782 | 0.901 | 0.774 | 0.912 | 1.001 | 2842.879 | 2894.833 |
| gp_lengthscale | 0.460 | 0.460 | 0.115 | 0.118 | 0.278 | 0.648 | 0.246 | 0.689 | 1.002 | 2614.311 | 2883.776 |
| gp_sigma | 1.428 | 1.305 | 0.503 | 0.391 | 0.860 | 2.412 | 0.798 | 2.767 | 1.000 | 2794.974 | 2578.639 |
There are many parameters, so let us explore only the trace of the model parameter with the lowest effective sample size. Posterior GP lengthscales commonly mix poorly, and we can observe some autocorrelation here, though it remains very mild and we would not be concerned at the levels seen:
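As a reminder of what drives `ess_bulk` down: the effective sample size shrinks roughly as \(N / (1 + 2\sum_k \rho_k)\), where \(\rho_k\) are the lag-\(k\) autocorrelations of the chain. A minimal sketch of a lag-1 autocorrelation estimate (in Python for illustration; the posterior package implements the real, more careful estimator):

```python
def lag1_autocorr(x):
    """Sample lag-1 autocorrelation of a sequence."""
    n = len(x)
    m = sum(x) / n
    var = sum((v - m) ** 2 for v in x)
    cov = sum((x[i] - m) * (x[i + 1] - m) for i in range(n - 1))
    return cov / var

# an alternating chain has negative lag-1 autocorrelation ...
print(lag1_autocorr([1, -1, 1, -1, 1, -1]))
# ... while a slowly drifting chain has positive lag-1 autocorrelation,
# which inflates the sum above and deflates the effective sample size
print(lag1_autocorr([1, 2, 3, 4, 5, 6]))
```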
# plot traces of parameter with smallest ess_bulk
# ===== Define a function to help us plot trace plots =====
plot_min_ess_trace <- function(stan_fit, sum_tbl) {
var_ess_min <- sum_tbl$variable[which.min(sum_tbl$ess_bulk)]
# often helpful to plot the log posterior density too
po_draws <- rstan::extract(stan_fit,
pars = c("lp__", var_ess_min),
inc_warmup = TRUE)
po_draws <- posterior::as_draws(po_draws)
p <- bayesplot::mcmc_trace(po_draws,
pars = c("lp__", var_ess_min),
n_warmup = 500,
facet_args = list(nrow = 2))
return(p)
}
# ==========================================================
p <- plot_min_ess_trace(wfp_gp_fit, sum_tbl)
ggsave(file = file.path(out.dir,'wfp_gp_worst_trace.png'), p, width = 12, height = 10)

Let us practice post-processing key aspects of our non-parametric Bayesian model.
A very common task is to obtain posterior median estimates and 95% credible intervals for target quantities. Here, our target quantity is the median wildfire prevalence \(\exp(\mu(t))\) evaluated at our time inputs. Remember the golden rule shown below:
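Why the golden rule matters: posterior means do not commute with non-linear transforms, but quantiles do commute with monotone ones, so transforming each draw first and summarizing afterwards gives correct medians and credible bounds. A quick standalone check (in Python for illustration; the toy draws are made up):

```python
import math

draws = [-2.0, -1.0, 0.0, 1.0, 2.0]           # toy posterior draws of mu

# transform first, then summarize
transformed = [math.exp(v) for v in draws]
mean_of_exp = sum(transformed) / len(transformed)
exp_of_mean = math.exp(sum(draws) / len(draws))
# the mean does NOT commute with exp (Jensen's inequality)
print(mean_of_exp, exp_of_mean)

# the median (a quantile) DOES commute with the monotone map exp
median_of_exp = sorted(transformed)[len(draws) // 2]
exp_of_median = math.exp(sorted(draws)[len(draws) // 2])
print(median_of_exp, exp_of_median)
```

The same logic carries over to the data.table code below: we exponentiate the `mu` draws first and only then take quantiles.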
# ====== Define a function to help us with extracting the draws =====
extract_posterior_draws <- function(stan_fit, par) {
# Extract draws for a specific parameter
po_draws <- rstan::extract(stan_fit,
pars = par,
permuted = TRUE,
inc_warmup = FALSE)
# Coerce into a format that is easier to work with
po_draws <- posterior::as_draws_df(po_draws[[par]])
dt_po <- as.data.table(po_draws)
setnames(dt_po, names(po_draws), gsub("\\.","", names(po_draws)))
# extract indices of mu as column in data.table
dt_po <- data.table::melt(dt_po,
id.vars = c('draw','chain','iteration'))
set(dt_po, NULL, 'variable', dt_po[, as.character(variable)])
set(dt_po, NULL, 'IDX', dt_po[, as.integer(gsub('(.*)\\[(.*)\\]','\\2',variable))])
set(dt_po, NULL, 'variable', dt_po[, gsub('(.*)\\[(.*)\\]','\\1',variable)])
return(dt_po)
}
# extract Monte Carlo samples of joint posterior and transformed parameter mu
dt_po <- extract_posterior_draws(wfp_gp_fit, "mu")
# golden rule: first transform, then summarize!
dt_po[, median_prevalence := exp(value)]
dt_po_sum <- dt_po[, list(
V = quantile(median_prevalence, probs = c(0.5, 0.025, 0.975)),
STAT = c('M','CL','CU')
), by = c('IDX')]
dt_po_sum <- dcast.data.table(dt_po_sum, IDX ~ STAT, value.var = 'V')
dt_po_sum <- merge(wfp, dt_po_sum, by = 'IDX')
dt_po_sum[, DATE := as.Date(paste0(FIRE_YEAR,'-',MONTH,'-15'))]

# plot posterior median point estimates and 95% CRI
p <- ggplot(dt_po_sum, aes(x = DATE)) +
geom_point(aes(y = M), colour = 'darkorange', shape = 1) +
geom_linerange(aes(ymin = CL, ymax = CU), colour = 'darkorange', linewidth = 0.4) +
geom_line(aes(y = PREVALENCE)) +
geom_point(aes(y = PREVALENCE)) +
scale_x_date(date_breaks = '6 months') +
scale_y_continuous(labels = scales::percent) +
labs(x = '', y = 'US land area on fire') +
theme_bw() +
theme(axis.text.x = element_text(angle = 45,vjust = 1,hjust = 1))
ggsave(file = file.path(out.dir,'US_wild_fires_prevalence_gp_linerange.png'), p, w = 12, h = 6)

Another common task is to plot a few realizations of the random functions. We will again focus on the random function corresponding to median wildfire prevalence \(\exp(\mu(t))\), which is induced by the random month effects plus the linear year effect. Note the minor variations in the magnitude and shape of the sampled functions:
# plot sample of 4 random functions evaluated at monthly inputs
tmp <- sort(sample(max(dt_po$draw), 4))
tmp <- data.table(draw = tmp)
tmp <- merge(tmp, dt_po, by = 'draw')
tmp[, median_prevalence := exp(value)]
tmp <- merge(wfp, tmp, by = 'IDX')
tmp[, DATE := as.Date(paste0(FIRE_YEAR,'-',MONTH,'-15'))]

p <- ggplot(tmp, aes(x = DATE)) +
geom_line(aes(y = median_prevalence), colour = 'darkorange') +
geom_line(aes(y = PREVALENCE), colour = 'black') +
scale_x_date(date_breaks = '6 months') +
scale_y_continuous(labels = scales::percent) +
labs(x = '', y = 'US land area on fire') +
facet_wrap(~draw, ncol = 2) +
theme_bw() +
theme(axis.text.x = element_text(angle = 45,vjust = 1,hjust = 1))
ggsave(file = file.path(out.dir,'US_wild_fires_prevalence_gp_random_functions.png'), p, w = 12, h = 12)

To help understand how the model works and/or inspect potential coding issues, it can also be useful to inspect the estimated shape of the random functions evaluated at the standardized inputs:
# plot the fitted GP
tmp <- sort(sample(max(dt_po$draw), 100))
tmp <- data.table(draw = tmp)
tmp[, SAMPLE_IDX := seq_len(nrow(tmp))]
dt_po_f <- extract_posterior_draws(wfp_gp_fit, "f")
tmp <- merge(tmp, dt_po_f, by = 'draw')
setnames(tmp, 'IDX', 'MONTH')
tmp <- merge(tmp, unique(subset(wfp, select = c(MONTH, INPUT_2))), by = 'MONTH')
set(tmp, NULL, 'draw', tmp[, factor(draw)])
p <- ggplot(tmp, aes(x = INPUT_2, colour = draw)) +
geom_line(aes(y = value)) +
labs(x = 'standardised inputs', y = 'value of GP') +
theme_bw()
ggsave(file = file.path(out.dir,'gp_random_functions.png'), p, w = 12, h = 6)

Let us now model \(y_i\) using computationally scalable Hilbert-Space GP approximations:
\[\begin{align*} & y_i \sim \text{Normal}(\mu_i, \sigma^2) \\ & \mu_i = \beta_0 + \beta_1 X_{i1} + f(\text{month}_i) \\ & \beta_0 \sim \text{Normal}(0, 100) \\ & \beta_1 \sim \text{Normal}(0, 1) \\ & f \sim \text{HSGP}(\alpha, \rho) \\ & \alpha \sim \text{Half-Cauchy}(0, 1) \\ & \rho \sim \text{Inv-Gamma}(5, 1) \\ & \sigma \sim \text{Half-Cauchy}(0,1) \end{align*}\] where \(X_{i1}\) is the standardized log year associated with the \(i\)th observation and so \(\beta_1\) models a linear annual effect, and \(f\) is a random function that is evaluated at monthly inputs and so captures month effects non-parametrically. The random function is given a zero-mean HSGP prior with squared exponential kernel with GP variance \(\alpha\) and lengthscale \(\rho\). The hyper-parameters \(\alpha\), \(\rho\) are given default priors that are suitable for a standardised input domain \([0,1]\).
Here is the Stan model file.
Note how the HSGP basis functions are precomputed at the standardised inputs once and for all in the transformed data block, and how the HSGP approximation is constructed through relatively cheap matrix multiplications.
# compile Stan model
wfp_hsgp_model <- rstan::stan_model("stan_models/wfp_hsgp.stan", model_name = "wfp_hsgp")

For the purposes of this lab we set the HSGP boundary factor to \(1.2\) and the number of HSGP basis functions to \(30\); see the paper “Practical Hilbert space approximate Bayesian Gaussian processes for probabilistic programming” for more details.
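Under the hood, the HSGP approximates \(f\) on an interval \([-L, L]\), with \(L\) set via the boundary factor, using \(m\) sine eigenfunctions weighted by the square root of the kernel's spectral density. A sketch of the construction (in Python for illustration; the formulas follow the paper above, and the values of `alpha`, `rho`, `L` and `z` are made up for the example):

```python
import math

def hsgp_basis(x, j, L):
    """j-th Laplacian eigenfunction on [-L, L]."""
    w_j = j * math.pi / (2 * L)              # square root of the j-th eigenvalue
    return math.sqrt(1.0 / L) * math.sin(w_j * (x + L))

def se_spectral_density(w, alpha=1.0, rho=0.5):
    """Spectral density of the squared exponential kernel."""
    return alpha**2 * math.sqrt(2 * math.pi) * rho * math.exp(-0.5 * (rho * w) ** 2)

def hsgp_draw(x, z, L=1.2, alpha=1.0, rho=0.5):
    """f(x) ~ sum_j sqrt(S(w_j)) * phi_j(x) * z_j, with z iid N(0, 1)."""
    total = 0.0
    for j, zj in enumerate(z, start=1):
        w_j = j * math.pi / (2 * L)
        total += math.sqrt(se_spectral_density(w_j, alpha, rho)) * hsgp_basis(x, j, L) * zj
    return total

# each basis function vanishes at the boundary +/- L, which is why the
# boundary factor must place L comfortably outside the input range
print(hsgp_basis(1.2, 3, 1.2))
```

Because the basis depends only on the inputs and on \(L\), it can be precomputed once in the transformed data block, leaving only cheap matrix multiplications per iteration.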
In this application, the variance-covariance matrix at the inputs is just a \(12 \times 12\) matrix, and so you won’t notice any computational speedups when using HSGPs:
# define data in format needed for model specification
stan_data <- list()
stan_data$N <- nrow(wfp)
stan_data$X <- unname(as.matrix(subset(wfp, select = FIRE_YEAR_2)))
stan_data$K <- ncol(stan_data$X)
stan_data$y <- wfp$LOG_PREV
stan_data$hsgp_c <- 1.2
stan_data$hsgp_M <- 30
stan_data$NI <- 12
stan_data$inputs_standardised <- unique(sort(wfp$INPUT_2))
stan_data$map_unique_inputs_to_obs <- wfp$MONTH
# sample from joint posterior of the HSGP model with rstan
wfp_hsgp_fit <- rstan::sampling(wfp_hsgp_model,
data = stan_data,
seed = 123,
chains = 2,
cores = 2,
warmup = 500,
iter = 2000,
init = list(
list(beta0 = -9, gp_sigma = 1, gp_lengthscale = .5, sigma = 1),
list(beta0 = -9, gp_sigma = 1, gp_lengthscale = .5, sigma = 1)
))
#> Warning: There were 46 divergent transitions after warmup. See
#> https://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
#> to find out why this is a problem and how to eliminate them.
#> Warning: Examine the pairs() plot to diagnose sampling problems
saveRDS(wfp_hsgp_fit, file = file.path(out.dir, "wfp_hsgp_fit.rds"))

You will once more note a few divergent transitions. We will ignore these for now; if you are interested in learning more, consider the Stan primer on divergent transitions.
As always we inspect convergence and mixing:
# load output from RDS
wfp_hsgp_fit <- readRDS(file.path(out.dir, "wfp_hsgp_fit.rds"))
model_pars <- c("beta0", "beta", "sigma", "gp_lengthscale", "gp_sigma")
sum_tbl <- make_diagnostic_summary(wfp_hsgp_fit, model_pars)
# print summaries
kableExtra::kbl(sum_tbl, digits = 3) %>%
  kableExtra::kable_styling(bootstrap_options = c("striped", "hover", "condensed"), font_size = 12)

| variable | mean | median | sd | mad | q5 | q95 | q2.5 | q97.5 | rhat | ess_bulk | ess_tail |
|---|---|---|---|---|---|---|---|---|---|---|---|
| beta0 | -10.491 | -10.302 | 1.105 | 1.290 | -12.294 | -8.905 | -12.509 | -8.678 | 1.004 | 3084.373 | 2964.866 |
| beta | 0.160 | 0.160 | 0.050 | 0.050 | 0.079 | 0.242 | 0.067 | 0.258 | 1.000 | 3128.845 | 2842.497 |
| sigma | 0.839 | 0.838 | 0.035 | 0.034 | 0.782 | 0.900 | 0.773 | 0.912 | 1.001 | 2896.961 | 2856.204 |
| gp_lengthscale | 0.461 | 0.435 | 0.137 | 0.148 | 0.269 | 0.702 | 0.249 | 0.732 | 1.004 | 2971.402 | 2808.257 |
| gp_sigma | 1.698 | 1.523 | 0.691 | 0.610 | 0.885 | 3.109 | 0.828 | 3.444 | 1.004 | 3078.690 | 2642.855 |
# plot traces of parameter with smallest ess_bulk
# often helpful to plot the log posterior density too
p <- plot_min_ess_trace(wfp_hsgp_fit, sum_tbl)
ggsave(file = file.path(out.dir,'wfp_hsgp_model_worst_trace.png'), p, w = 12, h = 10)

The code for summarising Monte Carlo samples of target quantities into posterior median estimates and 95% credible intervals, and for plotting samples of random functions, is exactly the same as before. The main point is that HSGP priors induce qualitatively very similar statistical behaviour to GP priors, while in larger dimensions they are substantially faster to evaluate. Get familiar with the code, so you can use it for your own purposes.
# extract Monte Carlo samples of joint posterior and transformed parameter mu
dt_po_mu <- extract_posterior_draws(wfp_hsgp_fit, "mu")
# summarize posterior median prevalence, and merge with data and inputs
dt_po_mu[, median_prevalence := exp(value)]
dt_po_mu_sum <- dt_po_mu[, list(
V = quantile(median_prevalence, probs = c(0.5, 0.025, 0.975)),
STAT = c('M','CL','CU')
), by = c('IDX')]
dt_po_mu_sum <- dcast.data.table(dt_po_mu_sum, IDX ~ STAT, value.var = 'V')
dt_po_mu_sum <- merge(wfp, dt_po_mu_sum, by = 'IDX')
dt_po_mu_sum[, DATE := as.Date(paste0(FIRE_YEAR,'-',MONTH,'-15'))]
# plot posterior median point estimates and 95% CRI
p <- ggplot(dt_po_mu_sum, aes(x = DATE)) +
geom_point(aes(y = M), colour = 'darkorange', shape = 1) +
geom_linerange(aes(ymin = CL, ymax = CU), colour = 'darkorange', linewidth = 0.4) +
geom_line(aes(y = PREVALENCE)) +
geom_point(aes(y = PREVALENCE)) +
scale_x_date(date_breaks = '6 months') +
scale_y_continuous(labels = scales::percent) +
labs(x = '', y = 'US land area on fire') +
theme_bw() +
theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1))
ggsave(file = file.path(out.dir,'US_wild_fires_prevalence_hsgp_linerange.png'), p, w = 12, h = 6)

# plot sample of 4 random functions evaluated at monthly inputs
tmp <- sort(sample(max(dt_po_mu$draw), 4))
tmp <- data.table(draw = tmp)
tmp <- merge(tmp, dt_po_mu, by = 'draw')
tmp[, median_prevalence := exp(value)]
tmp <- merge(wfp, tmp, by = 'IDX')
tmp[, DATE := as.Date(paste0(FIRE_YEAR,'-',MONTH,'-15'))]
p <- ggplot(tmp, aes(x = DATE)) +
geom_line(aes(y = median_prevalence), colour = 'darkorange') +
geom_line(aes(y = PREVALENCE), colour = 'black') +
scale_x_date(date_breaks = '6 months') +
scale_y_continuous(labels = scales::percent) +
labs(x = '', y = 'US land area on fire') +
facet_wrap(~draw, ncol = 2) +
theme_bw() +
theme(axis.text.x = element_text(angle = 45,vjust = 1,hjust = 1))
ggsave(file = file.path(out.dir,'US_wild_fires_prevalence_hsgp_random_functions.png'), p, w = 12, h = 12)

# plot the fitted HSGP
# note: re-extract f from the HSGP fit; dt_po_f still holds draws from the GP fit
dt_po_f <- extract_posterior_draws(wfp_hsgp_fit, "f")
tmp <- sort(sample(max(dt_po_f$draw), 100))
tmp <- data.table(draw = tmp)
tmp[, SAMPLE_IDX := seq_len(nrow(tmp))]
tmp <- merge(tmp, dt_po_f, by = 'draw')
setnames(tmp, 'IDX', 'MONTH')
tmp <- merge(tmp, unique(subset(wfp, select = c(MONTH, INPUT_2))), by = 'MONTH')
set(tmp, NULL, 'draw', tmp[, factor(draw)])
p <- ggplot(tmp, aes(x = INPUT_2, colour = draw)) +
geom_line(aes(y = value)) +
labs(x = 'standardised inputs', y = 'value of HSGP') +
theme_bw()
ggsave(file = file.path(out.dir,'hsgp_random_functions.png'), p, w = 12, h = 6)

Our previous plots of the estimated median wildfire prevalence clearly show that the model fails to reproduce the precise seasonal features in wildfire prevalence. Wildfire prevalence is more irregular from year to year than we currently model, and the occasional explosive peaks seen in the data are not captured.
How could the model be improved?
For example, how about adding iid annual effects? See if you can give it a go.
# TODO define year index and add year index to wfp
# TODO define data in format needed for model specification
# stan_data <-
# TODO use the previous Stan model on this new data set
# wfp_hsgp_2_fit <- rstan::sampling(wfp_hsgp_model,
# data = stan_data,
# seed = 123,
# chains = 2,
# cores = 2,
# warmup = 500,
# iter = 5000,
# refresh = 500,
# init = list(list(beta0 = -9, gp_sigma = 1, gp_lengthscale = .5, sigma = 1),
# list(beta0 = -9, gp_sigma = 1, gp_lengthscale = .5, sigma = 1)))
# TODO reproduce earlier analyses